PyTorch Lightningで書き換え
かつて https://nikkie-ftnext.hatenablog.com/entry/lightning-2steps-v10 にまとめた
LightningModule
ネットワークの定義
訓練で呼ばれるメソッドの上書き
例:メトリクス
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#train-epoch-level-operations
Trainer
LightningModuleとそれに渡すデータを扱う
PyTorch LightningのTrainerの仕組み
個々のDataLoaderの例
https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#basic-use
PyTorch LightningのCallback
LightningDataModule(PyTorch LightningのLightningDataModule)
データをまとめられる
以下のコードは https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#using-a-datamodule から
3つを使うとコードが劇的にスッキリする
code:python
dm = MNISTDataModule()
model = Model()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)